package com.plexpt.chatgpt;

import com.plexpt.chatgpt.entity.chat.Message;
import com.plexpt.chatgpt.listener.ConsoleStreamListener;
import com.plexpt.chatgpt.util.Proxys;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.math.BigDecimal;
import java.net.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;

import cn.hutool.core.util.NumberUtil;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;


/**
 * Console demo that chats with the ChatGPT streaming API.
 *
 * <p>The program asks for an API key and an optional proxy, then loops: it reads a question from
 * standard input, sends it together with a fixed role-play system message, and hands the streamed
 * answer to a {@link ConsoleStreamListener}. Every piece of input is terminated by an empty line,
 * which is why the user is told to press Enter twice.
 *
 * @author jiaqi
 */
@Slf4j
public class ConsoleChatGPT {

    /** Proxy used for API calls; stays {@link Proxy#NO_PROXY} unless the user configures one. */
    public static Proxy proxy = Proxy.NO_PROXY;

    /**
     * Single shared reader over standard input. Creating a new {@link BufferedReader} for every
     * read would throw away whatever the previous reader had already buffered, losing input that
     * was pasted or piped ahead of time.
     */
    private static final BufferedReader STDIN =
            new BufferedReader(new InputStreamReader(System.in));

    /**
     * Entry point: collects the API key and proxy settings, then runs the question/answer loop
     * until standard input is closed or the main thread is interrupted.
     *
     * @param args ignored
     */
    public static void main(String[] args) {

        // History forwarded with each request; capacity 1 keeps only the newest message.
        List<Message> beforeMessages = new FixSizeLinkedList<>(1);

        System.out.println("ChatGPT - Java command-line interface");
        System.out.println("Press enter twice to submit your question.");
        System.out.println();
        System.out.println("按两次回车以提交您的问题！！！");
        System.out.println("按两次回车以提交您的问题！！！");
        System.out.println("按两次回车以提交您的问题！！！");

        System.out.println();
        System.out.println("Please enter APIKEY, press Enter twice to submit:");
        // Trim so that whitespace picked up while pasting does not end up in the key.
        String key = getInput("请输入APIKEY，按两次回车以提交:\n").trim();
        check(key);

        configureProxy();

        // Balance check, currently disabled (see getBalance):
        // System.out.println("Inquiry balance...");
        // System.out.println("查询余额中...");
        // BigDecimal balance = getBalance(key);
        // System.out.println("API KEY balance: " + balance.toPlainString());
        //
        // if (!NumberUtil.isGreater(balance, BigDecimal.ZERO)) {
        //     System.out.println("API KEY 余额不足: ");
        //     return;
        // }

        // Owner's name
        String myName = "马克";
        // Owner's hobby
        String myHobby = "音乐";
        // Owner's job
        String myJob = "程序员";
        // Pet's name
        String petName = "二哈";
        // Pet's breed
        String petBreed = "哈士奇";
        // Pet's hobby
        String petHobby = "拆家";
        // Pet's temperament
        String petCharacter = "粘人并且温顺";
        // How close owner and pet are
        String intimacy = "非常好";

        // Role-play instructions sent as the system message of every request.
        Message system = Message.ofSystem(
                "你现在不是ai智能，你是一只无所不能的" + petBreed + "，你的名字叫" + petName + "，喜欢" + petHobby + "，"
                        + "你的性格" + petCharacter + "，我叫" + myName + "，我是你的主人，我喜欢" + myHobby
                        + "，我的工作是" + myJob + "，我们的关系" + intimacy);

        // Key and proxy are fixed from here on, so one client is built and used for every question.
        // NOTE(review): assumes a ChatGPTStream instance may be reused across calls - confirm.
        ChatGPTStream chatGPT = ChatGPTStream.builder()
                .apiKey(key)
                .proxy(proxy)
                .build()
                .init();

        while (true) {
            String prompt = readBlock("\nYou:\n");
            if (prompt == null) {
                // Standard input is exhausted; looping on would send empty prompts forever.
                break;
            }
            if (prompt.trim().isEmpty()) {
                // Nothing was typed; ask again instead of sending an empty question.
                continue;
            }

            System.out.println("AI: ");

            // Each question is addressed to the pet by name.
            if (!ask(chatGPT, system, beforeMessages, petName + "，" + prompt)) {
                break;
            }
        }
    }

    /**
     * Asks whether a proxy should be used (the original note says one is needed in mainland
     * China) and, if so, reads its type, host and port and stores the result in {@link #proxy}.
     * Any type other than {@code http} (case-insensitive) selects SOCKS5.
     *
     * @throws IllegalArgumentException if the entered port is not a number in 1-65535
     */
    private static void configureProxy() {
        System.out.println("是否使用代理？(y/n): ");
        System.out.println("use proxy？(y/n): ");
        String useProxy = getInput("按两次回车以提交:\n");
        if (!"y".equalsIgnoreCase(useProxy.trim())) {
            return;
        }

        // Proxy type
        System.out.println("请输入代理类型(http/socks): ");
        String type = getInput("按两次回车以提交:\n").trim();

        // Proxy address
        System.out.println("请输入代理IP: ");
        String proxyHost = getInput("按两次回车以提交:\n").trim();

        // Proxy port
        System.out.println("请输入代理端口: ");
        int proxyPort = parsePort(getInput("按两次回车以提交:\n"));

        if ("http".equalsIgnoreCase(type)) {
            proxy = Proxys.http(proxyHost, proxyPort);
        } else {
            proxy = Proxys.socks5(proxyHost, proxyPort);
        }
    }

    /**
     * Parses a TCP port typed by the user.
     *
     * @param text raw user input; surrounding whitespace is ignored
     * @return the port number, between 1 and 65535
     * @throws IllegalArgumentException if the text is not an integer in that range
     */
    private static int parsePort(String text) {
        String trimmed = text == null ? "" : text.trim();
        int port;
        try {
            port = Integer.parseInt(trimmed);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Invalid proxy port: '" + trimmed + "'", e);
        }
        if (port < 1 || port > 65535) {
            throw new IllegalArgumentException("Proxy port out of range (1-65535): " + port);
        }
        return port;
    }

    /**
     * Sends one user turn and blocks until the listener reports completion or an error.
     *
     * <p>The user message is appended to {@code history} before the request is built; the
     * assistant reply is appended when the completion callback fires.
     *
     * @param chatGPT  client used to start the streaming request
     * @param system   system message placed first in every request
     * @param history  messages sent after the system message; updated by this call
     * @param userText text of the user message
     * @return {@code false} if the thread was interrupted while waiting, {@code true} otherwise
     */
    private static boolean ask(ChatGPTStream chatGPT, Message system, List<Message> history,
                               String userText) {
        // Holds the caller until one of the two callbacks below fires.
        CountDownLatch done = new CountDownLatch(1);

        ConsoleStreamListener listener = new ConsoleStreamListener() {
            @Override
            public void onError(Throwable throwable, String response) {
                throwable.printStackTrace();
                done.countDown();
            }
        };
        history.add(Message.of(userText));
        listener.setOnComplate(msg -> {
            history.add(Message.ofAssistant(msg));
            done.countDown();
        });

        List<Message> messages = new ArrayList<>();
        messages.add(system);
        messages.addAll(history);

        chatGPT.streamChatCompletion(messages, listener);

        try {
            done.await();
            return true;
        } catch (InterruptedException e) {
            // Restore the flag so code further up the stack can see the interruption.
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /**
     * Queries the balance of the given API key through a non-streaming client that uses the
     * current {@link #proxy}. Not called at the moment; the balance check in {@code main} is
     * commented out.
     *
     * @param key API key to query
     * @return the balance reported by {@code ChatGPT.balance()}
     */
    private static BigDecimal getBalance(String key) {

        ChatGPT chatGPT = ChatGPT.builder()
                .apiKey(key)
                .proxy(proxy)
                .build()
                .init();

        return chatGPT.balance();
    }

    /**
     * Rejects a missing API key.
     *
     * @param key API key entered by the user
     * @throws IllegalArgumentException if {@code key} is {@code null}, empty or only whitespace
     */
    private static void check(String key) {
        if (key == null || key.trim().isEmpty()) {
            throw new IllegalArgumentException("请输入正确的KEY");
        }
    }

    /**
     * Prints {@code prompt} and reads lines from standard input until an empty line or the end
     * of the stream.
     *
     * @param prompt text printed (without a trailing newline) before reading
     * @return the lines read, joined with {@code "\n"}; an empty string if nothing was read
     */
    public static String getInput(String prompt) {
        String block = readBlock(prompt);
        return block == null ? "" : block;
    }

    /**
     * Same as {@link #getInput(String)} but distinguishes "nothing typed" from "no more input".
     *
     * @param prompt text printed (without a trailing newline) before reading
     * @return the lines read, joined with {@code "\n"}; an empty string if the first line was
     *         empty; {@code null} if the stream ended or failed before any line was read
     */
    private static String readBlock(String prompt) {
        System.out.print(prompt);
        List<String> lines = new ArrayList<>();
        try {
            String line;
            while ((line = STDIN.readLine()) != null) {
                if (line.isEmpty()) {
                    // The empty line is the terminator and is not part of the result.
                    return String.join("\n", lines);
                }
                lines.add(line);
            }
        } catch (IOException e) {
            // Best effort: report the failure and return whatever was read so far.
            e.printStackTrace();
        }
        return lines.isEmpty() ? null : String.join("\n", lines);
    }

    /**
     * Linked list whose {@link #add(Object)} keeps at most {@code capacity} elements by dropping
     * the oldest ones. Only {@code add(Object)} enforces the limit; the other insertion methods
     * inherited from {@link LinkedList} are not overridden.
     *
     * @param <T> element type
     */
    static class FixSizeLinkedList<T> extends LinkedList<T> {
        private static final long serialVersionUID = 3292612616231532364L;

        /** Maximum number of elements retained by {@link #add(Object)}. */
        private final int capacity;

        /**
         * @param capacity maximum number of retained elements, at least 1
         * @throws IllegalArgumentException if {@code capacity} is less than 1
         */
        public FixSizeLinkedList(int capacity) {
            super();
            if (capacity < 1) {
                throw new IllegalArgumentException("capacity must be at least 1: " + capacity);
            }
            this.capacity = capacity;
        }

        /**
         * Appends {@code e}, first removing elements from the head until there is room for it.
         *
         * @param e element to append
         * @return {@code true}, as specified by {@link LinkedList#add(Object)}
         */
        @Override
        public boolean add(T e) {
            // Evict from the head (oldest first); the loop also recovers if the list was
            // over-filled through a method that does not enforce the limit.
            while (size() >= capacity) {
                super.removeFirst();
            }
            return super.add(e);
        }
    }
}

